import torch

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from Modules import *
from LinearModule_utils import *



#------------------------------------------------------------------#
#embedding structure [ signal, memory, position ] 
#I will introduce k blanks to separate different sub-sequences
#------------------------------------------------------------------#



#------------------------------------------------------------------#
#config contains the following important parameters: 
#config.signal_start : Start Index of current signal embeddings (0 always)
#config.signal_end : End Index of current signal
#self.memory_index : Start index of memorized embeddings (from a previous layer)
#config.memory_end : End Index of memorized embeddings (from a previous layer)
#config.position_start : Start index of one-hot position embeddings
#config.seq_length : Sequence length of the smaller model that we are trying to simulate
#config.blank_identifier : Index containing Identifiers for blank token
#config.num_blanks : Number of blanks to separate the sub-sequences
#config.num_attention_heads : Number of attention heads
#config.scale_embeddings : A scale to initialize different query, key matrices
#config.inner_lr : Inner learning rate to simulate sgd  
#------------------------------------------------------------------# 



#------------------------------------------------------------------##------------------------------------------------------------------##------------------------------------------------------------------# 
#Input: num_attnt_heads is the number of attention heads in the smaller model; din is the embedding dimension through the attention module of the small model whose forward pass we are trying to simulate
#Output: 3 attention layers
class AttentionForward(nn.Module):
    """Simulate the forward pass of a linear-attention layer of a smaller model.

    The layer is assembled from three sub-modules that are defined elsewhere
    in the project:

    * ``self.linear`` (``LinearForward``) -- applied to the hidden states and,
      with swapped-in weights, re-used for the key and value projections
      (``self.key_linear`` and ``self.value_linear`` are aliases of it).
    * ``self.attnt_module`` (``Attention``) -- initialised with fixed 0/1
      routing tensors built in ``__init__``.
    * ``self.gates`` (``Gates``) -- initialised from the position embedding so
      that blank positions and non-blank positions are gated differently.

    Args:
        config: project configuration object. This class reads
            ``num_attention_heads``, ``hidden_size``, ``position_dim``,
            ``seq_length``, ``num_blanks`` and ``gate_scale`` from it.
        din: embedding dimension of the simulated attention module.
        num_attnt_heads: number of attention heads of the simulated model;
            must not exceed ``config.num_attention_heads``.
        use_softmax: must be ``False`` (only linear attention is supported).
        projection_matrix: optional matrix forwarded to ``LinearForward``;
            when given, its second dimension is used as ``dout``.
        separate_QK: when true (and no ``projection_matrix`` is given) the
            linear module is built with ``dout = 2 * din`` and the separate
            key pass in ``forward`` is skipped.
        memory_index: start index of the memory region written in
            ``forward``; ``-1`` disables the memory write.
    """

    def __init__ (self, config, din, num_attnt_heads, use_softmax, projection_matrix=None, separate_QK=False, memory_index=0):
        super().__init__()

        assert use_softmax==False ,\
            "Currently I only use linear attention in this module"

        assert num_attnt_heads <= config.num_attention_heads,\
            "Number of attention heads should be at least the number of attention heads necessary to simulate"

        self.separate_QK = separate_QK
        if projection_matrix is not None:
            dout = projection_matrix.shape[1]
        elif separate_QK:
            dout = 2*din
        else:
            dout = din

        self.linear = LinearForward(config, din=din, dout=dout, use_softmax=use_softmax, projection_matrix=projection_matrix, memory_index=-1)

        # Key and value projections deliberately share the module above; the
        # weights they act on are supplied per call in forward().
        self.key_linear = self.linear
        self.value_linear = self.linear

        self.gates = Gates (config)

        self.din = din
        self.num_attnt_heads = num_attnt_heads
        self.config = config
        self.memory_index = memory_index

        total_heads = config.num_attention_heads
        head_dim = config.hidden_size // total_heads
        basemodel_head_dim = din // num_attnt_heads

        self.attnt_module = Attention (config, normalize=True, proj_conv2d=True, proj_conv_dim=head_dim, proj_transpose=True)

        assert din % head_dim == 0, \
               "a bug! 'din' should be divisible by head dimensions"

        num_partitions = din // head_dim

        assert num_attnt_heads % num_partitions == 0, \
               "Num of attention heads should be divisible by num of partitions"

        num_attnt_heads_per_partition = num_attnt_heads // num_partitions

        #--------------------------------#--------------------------------#
        #For all Attention heads on the embeddings
        #Query uses the first set of din coordinates and splits them among the first 'num_attnt_heads' attention heads
        #Key uses the second set of din coordinates and splits them among the first 'num_attnt_heads' attention heads
        #Value uses the third set of din coordinates and splits them among the first 'num_attnt_heads' attention heads
        #Key and Query of embeddings ignore the position dependence.
        #--------------------------------#--------------------------------#

        def head_routing(offset):
            # 0/1 tensor of shape (head_dim, total_heads, total_heads): entry
            # [:, i * heads_per_partition + j, i + offset] is set to 1.
            routing = torch.zeros((head_dim, total_heads, total_heads))
            for i in range(num_partitions):
                for j in range(num_attnt_heads_per_partition):
                    routing[ :, i * num_attnt_heads_per_partition + j, i + offset ] = 1.
            return routing

        def head_selector(same_rows_as_cols):
            # Per-head (head_dim x head_dim) matrix holding one identity block
            # of size basemodel_head_dim. Columns are always picked by
            # i % heads_per_partition; rows are either the leading
            # basemodel_head_dim rows (query/key) or the same range as the
            # columns (value).
            selector = torch.zeros((total_heads, head_dim, head_dim))
            for i in range(num_attnt_heads):
                partition = i % num_attnt_heads_per_partition
                cols = slice(partition*basemodel_head_dim, (partition + 1)*basemodel_head_dim)
                rows = cols if same_rows_as_cols else slice(0, basemodel_head_dim)
                selector[ i, rows, cols ] = torch.eye(basemodel_head_dim)
            return selector

        q_attn_head = head_routing(0)
        q_attn = head_selector(False)

        k_attn_head = head_routing(num_partitions)
        k_attn = head_selector(False)

        v_attn_head = head_routing(2 * num_partitions)
        v_attn = head_selector(True)

        #--------------------------------#--------------------------------#
        #For all Attention heads on the positions
        #Query, Key are set such that we never attend to the blank tokens!
        #--------------------------------#--------------------------------#

        #--------------------------------#--------------------------------#
        #The projection matrix takes the output of the attention heads, which has the required signal only in its first basemodel_head_dim coordinates
        #We merge them together and return them at the head of the embedding
        #--------------------------------#--------------------------------#
        c_proj_init = torch.zeros((head_dim, total_heads, total_heads))
        for i in range(num_partitions):
            c_proj_init[:, i, i*num_attnt_heads_per_partition: (i+1)*num_attnt_heads_per_partition] = 1.

        self.attnt_module.initialize_weights(q_attn_init=q_attn,\
                                             q_attn_init_head=q_attn_head,\
                                             k_attn_init=k_attn,\
                                             k_attn_init_head=k_attn_head,\
                                             v_attn_init=v_attn,\
                                             v_attn_init_head=v_attn_head,\
                                             c_proj_init=c_proj_init )

        #Initialize Gates
        #Ignore the changes for the blanks!
        #w, u, v, w_bias, u_bias, v_bias
        w = torch.zeros((1, 2*config.hidden_size))
        u = torch.zeros((1, 2*config.hidden_size))
        v = torch.zeros((1, 2*config.position_dim))
        w_bias = torch.zeros(2)
        u_bias = torch.zeros(2)
        v_bias = torch.zeros(2)

        #Input Gate is 1 on blanks and 0 for non-blanks
        # (the assignment requires position_dim - seq_length == num_blanks)
        v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_blanks)

        #Change Gate is 0 on blanks and 1 for non-blanks
        v [0, config.position_dim+config.seq_length: 2*config.position_dim] = - config.gate_scale * torch.ones(config.num_blanks)
        v_bias [1] += config.gate_scale

        self.gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)

    @staticmethod
    def _assert_empty(region, message):
        # Check emptiness with the absolute sum: a signed sum lets negative or
        # mutually cancelling entries slip through the "< 1e-10" test.
        assert torch.sum(torch.abs(region)).item() < 1e-10, message

#Compute Q^{\top} K \sum_{j} a_{i, j} \nabla y_i^{\top} x_j x_j - Q^{\top} K  \nabla y_i^{\top} y_i y_i
# + K^{\top} Q \sum_{j} a_{j, i} \nabla y_j^{\top} x_i x_j - K^{\top} Q \sum_{j} a_{j, i} \nabla y_j^{\top} y_j x_j
# + \sum_{j} a_{j, i} \nabla y_j

    def forward(self, hidden_states, position_states, key_weights=None, value_weights=None, icl_mask=None):
        """Run linear projections, masked attention and the output gate.

        Args:
            hidden_states: indexed as ``[batch, position, feature]``; the
                first ``config.num_blanks`` positions are the blank slots.
            position_states: passed through unchanged to the sub-modules.
            key_weights: replaces the blank positions of ``hidden_states``
                for the key pass. Required when ``separate_QK`` is false.
            value_weights: replaces the blank positions of ``hidden_states``
                for the value pass. Always required.
            icl_mask: optional tensor of shape (batch_size, seqlen, seqlen);
                ``icl_mask[b, i, j] = 1`` if token i attends to token j.

        Returns:
            The output of ``self.gates.forward``.

        Raises:
            AssertionError: if a region that is about to be written (key,
                value or memory portion) is not empty.
        """
        nb = self.config.num_blanks
        din = self.din

        linear_output = self.linear.forward(hidden_states, position_states)

        if not self.separate_QK:
            # Swap the key weights into the blank slots and re-run the shared
            # linear module; its result is placed right after the first din
            # coordinates of the non-blank positions.
            inp_hidden_states = torch.cat( [key_weights, hidden_states[:, nb:] ], dim=1)
            key_out = self.key_linear(inp_hidden_states, position_states)
            self._assert_empty(linear_output[:, nb:, din:], "Key portion not empty!")
            linear_output[:, nb:, din:] += key_out[:, nb:, :-din]

        # Same trick for the values, placed after the first 2*din coordinates.
        inp_hidden_states = torch.cat( [value_weights, hidden_states[:, nb:] ], dim=1)
        value_out = self.value_linear(inp_hidden_states, position_states)
        self._assert_empty(linear_output[:, nb:, 2*din:], "Value portion not empty!")
        linear_output[:, nb:, 2*din:] += value_out[:, nb:, :-2*din]

        #Send a mask such that the tokens don't attend to the blank tokens
        seq_len = hidden_states.shape[1]
        mask_fill = torch.finfo(self.attnt_module.p_attn.weight.dtype).min

        # One mask per example is needed as soon as an icl_mask is supplied;
        # otherwise a single broadcastable mask is enough. The mask is created
        # on the same device as the activations.
        mask_batch = 1 if icl_mask is None else icl_mask.shape[0]
        normalization_mask = torch.zeros( (mask_batch, 1, seq_len, seq_len), device=hidden_states.device )
        normalization_mask[:, :, :, :nb] = mask_fill

        #icl_mask needs to be a 3D tensor of shape (batch_size, seqlen, seqlen)
        #icl_mask[i, j] = 1 if token i attends to token j
        if icl_mask is not None:
            for i in range(mask_batch):
                sq1 = icl_mask[i].shape[0]
                sq2 = icl_mask[i].shape[1]
                # 1 where attention is disallowed, restricted to the lower triangle.
                blocked = torch.tril( torch.round(torch.clip(1. - icl_mask[i], 0., 1.)) )
                normalization_mask[i, :, nb: nb+sq1, nb: nb+sq2] = blocked * mask_fill

        attnt_output  = self.attnt_module.forward(linear_output, position_states, normalization_mask=normalization_mask) [0]

        if self.memory_index != -1:
            #keep Qx, Kx in memory!
            #Keep also x separately afterwards!
            mem = self.memory_index
            self._assert_empty(attnt_output[:, nb:, mem:], "Memory portion not empty!")

            attnt_output[:, nb:, mem: mem+2*din] += linear_output[:, nb:, :2*din]
            attnt_output[:, nb:, mem+2*din: mem+3*din] += hidden_states[:, nb:, :din]

        gate_output   = self.gates.forward(linear_output, attnt_output, position_states)

        return gate_output
    
    
    
    
class AttentionBackward(nn.Module):

    def __init__ (self, config, din, num_attnt_heads, use_softmax, retain_nablay=False, memory_index=0):
        super(AttentionBackward, self).__init__()
        
        assert use_softmax==False ,\
            "Currently I only use linear attention in this module"
        
        
        
        
        self.first_attnt_module = Attention (config, peak_into_future=True, normalize=use_softmax)
        self.first_attnt_gates = Gates (config)
        self.first_mlp = MLP(config.hidden_size, config)
        self.first_mlp_gates = Gates (config)
        self.second_mlp = MLP(config.hidden_size, config)
        self.second_mlp_gates = Gates (config)
        self.second_attnt_module = Attention (config, peak_into_future=True, normalize=use_softmax)
        self.second_attnt_gates = Gates (config)
        self.backward = LinearBackward(config, din=din, dout=3*din, use_softmax=False, retain_nablay=retain_nablay) 
        self.retain_nablay = retain_nablay
        
        ##### First attention module #######
        ########### Assumption #############
            #The memory part has the following format [y_i, x_i^(query), x_i^(key), x_i^(value), \{ a^h_ij \} ]; a^h_ij is the set of all attention scores for all heads
            #First attention module computes \nabla y_i^\top x_j^(value) for each j at position i and stores after \nabla y_i
            #Furthermore, we also compute \sum_j a_{j, i} \nabla y_j
        ########### Assumption #############
        
        
        
        head_dim = config.hidden_size // config.num_attention_heads
        basemodel_head_dim = din // num_attnt_heads     
        self.memory_index = memory_index
        
        assert self.memory_index <= config.hidden_size - ( 4 * din + config.seq_length * num_attnt_heads ), \
            "Not enough memory to simulate backward pass"
        assert config.seq_length <= head_dim ,\
            "Currently I assume the head dimension is atleast the sequence length of original model"
        
        #--------------------------------#--------------------------------#
        #For all Attention heads on the embeddings

        #Query uses the first set of din coordinates and splits them among the first 'num_attnt_heads' attention heads
        #Key uses the set of din coordinates that contains x_i^(value) in the memory and splits them among the first 'num_attnt_heads' attention heads
        #Value is all zeros for the first 'num_attnt_heads' attention heads
        #They only focus on signal (ignoers position)


        #for the second 'num_attnt_heads' attention heads
        #Query is all zeros
        #Key uses the set of attention coordinates in the memory and splits them among the first 'num_attnt_heads' attention heads
        #Value uses \nabla y and splits them among the first 'num_attnt_heads' attention heads
        #They only focus on cross dependence (signal-position)

        #--------------------------------#--------------------------------#
        query = torch.zeros((2*config.num_attention_heads, head_dim, config.hidden_size))
        #query [:num_attnt_heads, :basemodel_head_dim, :basemodel_head_dim] = torch.eye(basemodel_head_dim)
        for i in range (num_attnt_heads):
            query[i, :basemodel_head_dim, i*basemodel_head_dim: (i+1)*basemodel_head_dim] = torch.eye(basemodel_head_dim)

        
        query = query.view( (2*config.hidden_size, config.hidden_size) )
        
        key = torch.zeros((2*config.num_attention_heads, head_dim, config.hidden_size))
        #key [:num_attnt_heads, :basemodel_head_dim, :basemodel_head_dim] = torch.eye(basemodel_head_dim)
        for i in range (num_attnt_heads):
            key[i, :basemodel_head_dim, self.memory_index + 3 * din + i*basemodel_head_dim: self.memory_index + 3 * din + (i+1)*basemodel_head_dim] = torch.eye(basemodel_head_dim)
        
        
        
        for i in range (num_attnt_heads):
            key[config.num_attention_heads + num_attnt_heads + i, :config.seq_length, self.memory_index + 4 * din + i*config.seq_length: self.memory_index + 4 * din + (i+1)*config.seq_length] = torch.eye(config.seq_length)
        
        
        key = key.view( (2*config.hidden_size, config.hidden_size) )


        value = torch.zeros((config.num_attention_heads, head_dim, config.hidden_size))
        
        
        for i in range (num_attnt_heads):
            value[num_attnt_heads + i, :basemodel_head_dim,  i*basemodel_head_dim: (i+1)*basemodel_head_dim] = torch.eye(basemodel_head_dim)
        
        
        
        value = value.view( (config.hidden_size, config.hidden_size) )


        c_attn_init, c_attn_bias = torch.cat([query, key, value], axis=0), torch.zeros(5 * config.hidden_size)


        #--------------------------------#--------------------------------#
        #For all Attention heads on the positions
        #Query, Key are all zeros for the first 'num_attnt_heads' attention heads !
        #Value caries the one-hot position embeddings to the fore-front 

        #for the second 'num_attnt_heads' attention heads,
        #query looks at the set of one-hot encodings for the position of the original model
        #Key is all zeros
        #Value is all zeros
        #Only depend on cross dependence
        #--------------------------------#--------------------------------#
    
        query = torch.zeros((2*config.num_attention_heads, head_dim, config.position_dim))
        
        
        for i in range (num_attnt_heads, 2*num_attnt_heads):
            query[config.num_attention_heads + i, :config.seq_length, :config.seq_length] = torch.eye(config.seq_length)
          
        query = query.view((2*config.num_attention_heads*head_dim, config.position_dim))
        
        
        key = torch.zeros((2*config.num_attention_heads * head_dim, config.position_dim))
        
        value = torch.zeros((config.num_attention_heads, head_dim, config.position_dim))
        value[:num_attnt_heads, :config.seq_length, :config.seq_length] = torch.eye(config.seq_length)
        value = value.view( (config.hidden_size, config.position_dim) )


        p_attn_init, p_attn_bias = torch.cat([query, key, value], axis=0), torch.zeros(5 * config.hidden_size)
        
        #--------------------------------#--------------------------------#
        #The projection matrix after the attention module places the output of the attention module after \nabla y at each position i
        #--------------------------------#--------------------------------#
        c_proj_init, c_proj_bias = torch.zeros((config.hidden_size, config.hidden_size)), torch.zeros(config.hidden_size)
        for i in range(num_attnt_heads):
            for j in range(config.seq_length):
                c_proj_init [din + i * config.seq_length + j, i * head_dim + j] = 1.

        for i in range(din):
            c_proj_init [din + num_attnt_heads * config.seq_length + i, (num_attnt_heads  + (i // basemodel_head_dim)) * head_dim + i % basemodel_head_dim] = 1.



        self.first_attnt_module.initialize_weights(c_attn_init, \
                                        c_attn_bias, \
                                        p_attn_init, \
                                        p_attn_bias, \
                                        c_proj_init, \
                                        c_proj_bias)


        #Initialize the first attention Gates
        #Ignore the changes for the blanks!
        #w, u, v, w_bias, u_bias, v_bias
        w = torch.zeros((1, 2*config.hidden_size))
        u = torch.zeros((1, 2*config.hidden_size))
        v = torch.zeros((1, 2*config.position_dim))
        w_bias = torch.zeros(2)
        u_bias = torch.zeros(2)
        v_bias = torch.zeros(2)

        #Input Gate is 1
        v_bias [0] += config.gate_scale

        #Change Gate is 0 on blanks and 1 for non-blanks
        v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.position_dim-config.seq_length)
        v_bias [1] += config.gate_scale

        self.first_attnt_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)


        ################################## First MLP module ####################################
            #We compute \nabla y_i^\top y_i at each position i for each attention head
            #Recall our embedding currently looks at each position i like [ \nabla y_i, \{  \nabla y_j^\top x_i \}_j (for all heads), \sum_j a^h_ji y_j for all heads ]
            #The memory part has the following format [y_i, x_i^(query), x_i^(key), x_i^(value), \{ a^h_ij \} ]; a^h_ij is the set of all attention scores for all heads
        ##################################################################################      
        #c_fc_init, c_fc_bias, c_proj_init, c_proj_bias
        #After passing through c_fc_init, we will apply Gelu on top of the sub-embeddings that contain [\nabla y_i, y_i, \nabla y_i + y_i ], 
        ################################## MLP module ####################################
        #Now, we follow the same steps for \nabla y_i, y_i, \nabla y_i + y_i
        c_fc_init = torch.zeros(config.hidden_size, config.hidden_size)
        c_fc_bias = torch.zeros(config.hidden_size)

        #First, make sure that \nabla y_i + y_i is computed and stored in the dimensions [workspace + num_attnt_heads * config.seq_length, workspace + num_attnt_heads * config.seq_length + din]
        workspace = din + num_attnt_heads * config.seq_length + din
        first_index_left = workspace 
        first_index_right = first_index_left + din

        second_index_left = 0
        second_index_right = second_index_left + din

        c_fc_init[ first_index_left : first_index_right, second_index_left : second_index_right ] = 1./config.scale_embeddings * torch.eye(din)

        second_index_left = self.memory_index
        second_index_right = second_index_left + din

        c_fc_init[ first_index_left : first_index_right, second_index_left : second_index_right ] = 1./config.scale_embeddings * torch.eye(din)

        #Next, we make sure that \nabla y_i passes through
        second_index_left = 0
        second_index_right = second_index_left + din

        c_fc_init[ second_index_left : second_index_right, second_index_left : second_index_right ] = 1./config.scale_embeddings * torch.eye(din)

        #Next, we make sure that y_i passes through
        second_index_left = self.memory_index
        second_index_right = second_index_left + din

        c_fc_init[ second_index_left : second_index_right, second_index_left : second_index_right ] = 1./config.scale_embeddings * torch.eye(din)


        ###################################### c_proj_init ######################################
        c_proj_bias = torch.zeros(config.hidden_size)
        c_proj_init = torch.zeros(config.hidden_size, config.hidden_size)


        first_index_left =  workspace #+ num_attnt_heads * config.seq_length    
        second_index_left = workspace 
        for i in range(num_attnt_heads):
            c_proj_init[ first_index_left + i, second_index_left + i * basemodel_head_dim: second_index_left + (i+1) * basemodel_head_dim ] = np.sqrt(np.pi / 2) * (config.scale_embeddings ** 2) * torch.ones( (basemodel_head_dim, ) )
        

        second_index_left = 0
        for i in range(num_attnt_heads):
            c_proj_init[ first_index_left + i, second_index_left + i * basemodel_head_dim: second_index_left + (i+1) * basemodel_head_dim ] = -np.sqrt(np.pi / 2) * (config.scale_embeddings ** 2) * torch.ones( (basemodel_head_dim, ) )
        

        second_index_left = self.memory_index
        for i in range(num_attnt_heads):
            c_proj_init[ first_index_left + i, second_index_left + i * basemodel_head_dim: second_index_left + (i+1) * basemodel_head_dim ] = -np.sqrt(np.pi / 2) * (config.scale_embeddings ** 2) * torch.ones( (basemodel_head_dim, ) )
        

            

        self.first_mlp.initialize_weights(c_fc_init, c_fc_bias, c_proj_init, c_proj_bias)   


        #Initialize the first mlp Gates
        #Ignore the changes for the blanks!
        #w, u, v, w_bias, u_bias, v_bias
        w = torch.zeros((1, 2*config.hidden_size))
        u = torch.zeros((1, 2*config.hidden_size))
        v = torch.zeros((1, 2*config.position_dim))
        w_bias = torch.zeros(2)
        u_bias = torch.zeros(2)
        v_bias = torch.zeros(2)

        #Input Gate is 1
        v_bias [0] += config.gate_scale

        #Change Gate is 0 on blanks and 1 for non-blanks
        v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.position_dim-config.seq_length)
        v_bias [1] += config.gate_scale

        self.first_mlp_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)

        
        
        ################################## Second MLP module ####################################
            #computes (\nabla y_i^\top x_j  - \nabla y_i^\top y_i - 1) * a_ij
            #Recall our embedding currently looks at each position i like [ \nabla y_i, \{  a_ij \nabla y_i^\top x_j \}_j (for all heads), \sum_j a^h_ji \nabla y_j for all heads, \nabla y_i^\top y_i for all heads]
            #The memory part has the following format [y_i, x_i^(query), x_i^(key), x_i^(value), \{ a^h_ij \} ]; a^h_ij is the set of all attention scores for all heads
        ##################################################################################      
        #c_fc_init, c_fc_bias, c_proj_init, c_proj_bias
        #After passing through c_fc_init, we will apply Gelu on top of the sub-embeddings that contain [\nabla y_j^\top x_i  - \nabla y_i^\top y_i - 1, a_ij, \nabla y_j^\top x_i  - \nabla y_i^\top y_i - 1 + a_ij ], 
        #where the last column is computed and stored in the dimensions [workspace , workspace  + num_attnt_heads * config.seq_length ]

        ################################## MLP module ####################################


        workspace = 0
        c_fc_init = torch.zeros(config.hidden_size, config.hidden_size)
        c_fc_bias = torch.zeros(config.hidden_size)



        #First, make sure that \nabla y_i^\top x_j  - \nabla y_i^\top y_i - 1 + a_ij is computed and stored in the dimensions [workspace , workspace + num_attnt_heads * config.seq_length]
        first_index_left = workspace 
        first_index_right = first_index_left + num_attnt_heads * config.seq_length

        second_index_left = din 
        second_index_right = second_index_left + num_attnt_heads * config.seq_length

        c_fc_init[ first_index_left : first_index_right, second_index_left : second_index_right ] = 1./config.scale_embeddings * torch.eye(num_attnt_heads * config.seq_length)


        second_index_left = self.memory_index + 4 * din  
        second_index_right = second_index_left + num_attnt_heads * config.seq_length

        c_fc_init[ first_index_left : first_index_right, second_index_left : second_index_right ] = 1./config.scale_embeddings * torch.eye(num_attnt_heads * config.seq_length)
    

        second_index_left = din + num_attnt_heads * config.seq_length + din 
        second_index_right = second_index_left + num_attnt_heads

        for i in range(num_attnt_heads):
            c_fc_init[ first_index_left + i * config.seq_length : first_index_left + (i + 1) * config.seq_length , second_index_left + i ] = -1./config.scale_embeddings 

        c_fc_bias [ first_index_left : first_index_right ] -= 1./config.scale_embeddings    


        #Next, we make sure that a_ij passes through
        first_index_left = self.memory_index + 4 * din  
        first_index_right = first_index_left + num_attnt_heads * config.seq_length
        second_index_left = self.memory_index + 4 * din 
        second_index_right = second_index_left + num_attnt_heads * config.seq_length

        c_fc_init[ first_index_left : first_index_right, second_index_left : second_index_right ] = 1./config.scale_embeddings * torch.eye(num_attnt_heads * config.seq_length)

        #Next, we make sure that \nabla y_i^\top x_j  - \nabla y_i^\top y_i - 1 passes through
        first_index_left =  workspace + num_attnt_heads * config.seq_length
        first_index_right = first_index_left + num_attnt_heads * config.seq_length

        second_index_left = din
        second_index_right = second_index_left + num_attnt_heads * config.seq_length

        c_fc_init[ first_index_left : first_index_right, second_index_left : second_index_right ] = 1./config.scale_embeddings * torch.eye(num_attnt_heads * config.seq_length)

        second_index_left = din + num_attnt_heads * config.seq_length + din 
        second_index_right = second_index_left + num_attnt_heads

        for i in range(num_attnt_heads):
            c_fc_init[ first_index_left + i * config.seq_length : first_index_left + (i + 1) * config.seq_length , second_index_left + i ] = -1./config.scale_embeddings 

        c_fc_bias [ first_index_left : first_index_right ] -= 1./config.scale_embeddings    
        




        ###################################### c_proj_init ######################################
        #[\nabla y_i^\top x_j - \nabla y_i^\top y_i - 1, a_ij, \nabla y_i^\top x_j  - \nabla y_i^\top y_i - 1 + a_ij ]
        #c_proj_init will compute -Gelu(\nabla y_i^\top x_j  - \nabla y_i^\top y_i - 1) - Gelu(a_ij) + Gelu( \nabla y_i^\top x_j  - \nabla y_i^\top y_i - 1 + a_ij ) 
        #and store them in the dimensions where a_ij was before
        ###################################### c_proj_init ######################################
        
        c_proj_bias = torch.zeros(config.hidden_size)
        c_proj_init = torch.zeros(config.hidden_size, config.hidden_size)


        first_index_left =  self.memory_index + 4 * din 
        first_index_right = first_index_left + num_attnt_heads * config.seq_length

        second_index_left = workspace 
        second_index_right = second_index_left + num_attnt_heads * config.seq_length

    
        c_proj_init[ first_index_left: first_index_right, second_index_left: second_index_right ] = np.sqrt(np.pi / 2) * (config.scale_embeddings ** 2) * torch.eye( num_attnt_heads * config.seq_length)
       

        second_index_left  = self.memory_index + 4 * din 
        second_index_right = second_index_left + num_attnt_heads * config.seq_length
        
        c_proj_init[ first_index_left: first_index_right, second_index_left: second_index_right ] = -np.sqrt(np.pi / 2) * (config.scale_embeddings ** 2) * torch.eye( num_attnt_heads * config.seq_length)

        second_index_left = workspace + num_attnt_heads * config.seq_length
        second_index_right = second_index_left + num_attnt_heads * config.seq_length 
        
        #c_proj_init[ first_index_left: first_index_right, second_index_left: second_index_right ] = -config.scale_embeddings * torch.eye( num_attnt_heads * config.seq_length)
        c_proj_init[ first_index_left: first_index_right, second_index_left: second_index_right ] = -np.sqrt(np.pi / 2) * (config.scale_embeddings ** 2)  * torch.eye( num_attnt_heads * config.seq_length)

        self.second_mlp.initialize_weights(c_fc_init, c_fc_bias, c_proj_init, c_proj_bias)  

    
    
        #Initialize the second mlp Gates
        #Ignore the changes for the blanks!
        #w, u, v, w_bias, u_bias, v_bias
        w = torch.zeros((1, 2*config.hidden_size))
        u = torch.zeros((1, 2*config.hidden_size))
        v = torch.zeros((1, 2*config.position_dim))
        w_bias = torch.zeros(2)
        u_bias = torch.zeros(2)
        v_bias = torch.zeros(2)

        #Input Gate is 1
        v_bias [0] += config.gate_scale

        #Change Gate is 0 on blanks and 1 for non-blanks
        v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.position_dim-config.seq_length)
        v_bias [1] += config.gate_scale

        self.second_mlp_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)

        
        ##### Second attention module #######
        ########### Assumption #############
            #Recall our embedding currently looks at each position i like [ \nabla y_i, \{  a_ij \nabla y_j^\top x_i \}_j (for all heads), \sum_j a^h_ji \nabla y_j for all heads, \nabla y_i^\top y_i for all heads]
            #The memory part has the following format [y_i, x_i^(query), x_i^(key), x_i^(value), \{ a^h_ij * (\nabla y_j^\top x_i  - \nabla y_i^\top y_i) \} ]; a^h_ij is the set of all attention scores for all heads
        ########### Mechanism #############
        #First, set of 'num_attnt_head' heads : compute \nabla x_i^(query)
        #    - Query: Takes the embedding portion that contains a^h_ij * (\nabla y_j^\top x_i  - \nabla y_i^\top y_i) and splits among the attention heads
        #    - Key:   Takes one-hot position encoding of original model
        #    - Value: Takes the memory that contains x_i^(key)

        #First, set of 'num_attnt_head' heads : compute \nabla x_i^(key)
        #    - Query:   one-hot position encoding of original model
        #    - Key:     Takes the embedding portion that contains a^h_ij * (\nabla y_j^\top x_i  - \nabla y_i^\top y_i) and splits among the attention heads
        #    - Value:   Takes the memory that contains x_i^(query)


        #First, set of '1' head : compute \nabla x_i^(value)
        #    - Query:  Takes one-hot position encoding of original model
        #    - Key:    Takes one-hot position encoding of original model
        #    - Value:  Takes the embedding portion that contains \sum_{j, i} \nabla y_{j, i}

        ########### Mechanism #############



        #--------------------------------#--------------------------------#
        #For all Attention heads on the embeddings
        #for the first 'num_attnt_heads' attention heads !
        #    - Query : Takes the embedding portion that contains a^h_ij * (\nabla y_j^\top x_i  - \nabla y_i^\top y_i) and splits among the attention heads
        #    - Key   : 0s
        #    - Value : Takes the memory that contains x_i^(key)
        #Cross-dependence (signal-position) only

        #for the second 'num_attnt_heads' attention heads !
        #    - Query : 0s
        #    - Key   : Takes the embedding portion that contains a^h_ij * (\nabla y_j^\top x_i  - \nabla y_i^\top y_i) and splits among the attention heads
        #    - Value : Takes the memory that contains x_i^(query)
        #Cross-dependence (signal-position) only
   
        #for the final 1 attention head !
        #    - Query : 0s
        #    - Key   : 0s
        #    - Value : Takes the embedding portion that contains \sum_{j, i} \nabla y_{j, i}

        #--------------------------------#--------------------------------#

        query = torch.zeros((2*config.num_attention_heads, head_dim, config.hidden_size))
        
        for i in range (num_attnt_heads):
            query[config.num_attention_heads + i, :config.seq_length, self.memory_index + 4 * din + i*config.seq_length: self.memory_index + 4 * din + (i+1)*config.seq_length] = torch.eye(config.seq_length)
        
        query = query.view( (2*config.hidden_size, config.hidden_size) )
        



        key = torch.zeros((2*config.num_attention_heads, head_dim, config.hidden_size))
        for i in range (num_attnt_heads):
            key[config.num_attention_heads + num_attnt_heads + i, :config.seq_length, self.memory_index + 4 * din + i*config.seq_length: self.memory_index + 4 * din + (i+1)*config.seq_length] = torch.eye(config.seq_length)
    
        
        key = key.view( (2*config.hidden_size, config.hidden_size) )

        #split x_i^(key) among the heads
        value = torch.zeros((config.num_attention_heads, head_dim, config.hidden_size))
        for i in range (num_attnt_heads):
            value[i, :basemodel_head_dim,  self.memory_index + 2 * din + i*basemodel_head_dim: self.memory_index + 2 * din + (i+1)*basemodel_head_dim] = torch.eye(basemodel_head_dim)
        
        #split x_i^(query) among the heads
        for i in range (num_attnt_heads):
            value[num_attnt_heads + i, :basemodel_head_dim,  self.memory_index + din + i*basemodel_head_dim: self.memory_index + din + (i+1)*basemodel_head_dim] = torch.eye(basemodel_head_dim)

        
        value[2*num_attnt_heads, :din,  din + num_attnt_heads * config.seq_length: din + num_attnt_heads * config.seq_length + din] = torch.eye(din)

        value = value.view( (config.hidden_size, config.hidden_size) )


        c_attn_init, c_attn_bias = torch.cat([query, key, value], axis=0), torch.zeros(5 * config.hidden_size)


        #--------------------------------#--------------------------------#
        #For all Attention heads on the embeddings
        #for the first 'num_attnt_heads' attention heads !
        #    - Query : 0s
        #    - Key   : Takes one-hot position encoding of original model
        #    - Value : 0s
        # Cross dependence (Signal-position)

        #for the second 'num_attnt_heads' attention heads !
        #    - Query : Takes one-hot position encoding of original model
        #    - Key   : 0s
        #    - Value : 0s
        # Cross dependence (Signal-position)
   
        #for the final 1 attention head !
        #    - Query : Takes one-hot position encoding of original model
        #    - Key   : Takes one-hot position encoding of original model
        #    - Value : 0s 
        # Position-position
        #--------------------------------#--------------------------------#

        query = torch.zeros((2*config.num_attention_heads, head_dim, config.position_dim))
        #for i in range (num_attnt_heads, 2*num_attnt_heads):
        query[config.num_attention_heads + num_attnt_heads: config.num_attention_heads + 2 * num_attnt_heads, :config.seq_length, :config.seq_length] = torch.eye(config.seq_length)
        query[2*num_attnt_heads, :config.seq_length, :config.seq_length] = torch.eye(config.seq_length)

        query = query.view( (2*config.hidden_size, config.position_dim) )    

        key = torch.zeros((2*config.num_attention_heads, head_dim, config.position_dim))
        key[config.num_attention_heads: config.num_attention_heads + num_attnt_heads, :config.seq_length, :config.seq_length] = torch.eye(config.seq_length)



        key[2*num_attnt_heads, :config.seq_length, :config.seq_length] = torch.eye(config.seq_length)

        #for i in range (num_attnt_heads, 2*num_attnt_heads):
        #   key[i, :config.seq_length, :config.seq_length] = torch.eye(config.seq_length)
        key = key.view( (2*config.hidden_size, config.position_dim) )  

        value = torch.zeros((config.num_attention_heads, head_dim, config.position_dim)) 
        value = value.view( (config.hidden_size, config.position_dim) )


        p_attn_init, p_attn_bias = torch.cat([query, key, value], axis=0), torch.zeros(5 * config.hidden_size)
        
        #p_value_init, p_value_bias = value, torch.zeros(config.hidden_size)

        #--------------------------------#--------------------------------#
        #The projection matrix after the attention module places the output of the attention modules concatenated
        #--------------------------------#--------------------------------#
        c_proj_init, c_proj_bias = torch.zeros((config.hidden_size, config.hidden_size)), torch.zeros(config.hidden_size)
        for i in range(2*num_attnt_heads):
            for j in range(basemodel_head_dim):
                c_proj_init [i * basemodel_head_dim + j, i * head_dim + j] = 1.
                
        if self.retain_nablay:
            #we need to hold onto memory of query, key, value for a bit more time
            c_proj_init [self.memory_index: self.memory_index+3*din, self.memory_index+din: self.memory_index+4*din] = torch.eye(3*din)
            
        c_proj_init [2 * din: 3 * din, 2 * num_attnt_heads * head_dim : 2 * num_attnt_heads * head_dim + din] = torch.eye(din)

        self.second_attnt_module.initialize_weights(c_attn_init, \
                                        c_attn_bias, \
                                        p_attn_init, \
                                        p_attn_bias, \
                                        c_proj_init, \
                                        c_proj_bias)

        
        #Initialize the second attention Gates
        #Ignore the changes for the blanks!
        #w, u, v, w_bias, u_bias, v_bias
        w = torch.zeros((1, 2*config.hidden_size))
        u = torch.zeros((1, 2*config.hidden_size))
        v = torch.zeros((1, 2*config.position_dim))
        w_bias = torch.zeros(2)
        u_bias = torch.zeros(2)
        v_bias = torch.zeros(2)

        #Input Gate is 1 on blanks and 0 for non-blanks
        v [0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.position_dim-config.seq_length)

        #Change Gate is 0 on blanks and 1 for non-blanks
        v [0, config.position_dim+config.seq_length: 2*config.position_dim] = -config.gate_scale * torch.ones(config.position_dim-config.seq_length)
        v_bias [1] += config.gate_scale

        self.second_attnt_gates.initialize_weights (w, u, v, w_bias, u_bias, v_bias)
        
        
    def forward(self, hidden_states, position_states, attention_mask):
        """Run the sub-modules in sequence and return the final output.

        Stages, in order: first attention -> first MLP -> second MLP ->
        second attention -> ``self.backward``. Each of the first four stages
        is followed by its gate module, which receives the stage input, the
        stage output and ``position_states``.

        Sub-modules are invoked through ``__call__`` rather than ``.forward``
        so that any registered forward / pre-forward hooks are executed.
        NOTE(review): this assumes every sub-module is an ``nn.Module``
        (i.e. callable) -- confirm against their definitions.

        Args:
            hidden_states: Input fed to the first attention module and to the
                first attention gate.
            position_states: Passed to both attention modules, to every gate
                and to ``self.backward``.
            attention_mask: Forwarded as the ``attention_mask`` keyword to
                both attention modules.

        Returns:
            The result of ``self.backward`` applied to the last gate output
            and ``position_states``.
        """
        # The attention modules return an indexable result; only element 0 is used.
        attnt_output = self.first_attnt_module(
            hidden_states, position_states, attention_mask=attention_mask
        )[0]
        gate_output = self.first_attnt_gates(hidden_states, attnt_output, position_states)

        mlp_output = self.first_mlp(gate_output)
        gate_output = self.first_mlp_gates(gate_output, mlp_output, position_states)

        mlp_output = self.second_mlp(gate_output)
        gate_output = self.second_mlp_gates(gate_output, mlp_output, position_states)

        attnt_output = self.second_attnt_module(
            gate_output, position_states, attention_mask=attention_mask
        )[0]
        gate_output = self.second_attnt_gates(gate_output, attnt_output, position_states)

        return self.backward(gate_output, position_states)
        
        
#Take Gradient w.r.t. Q, K, V
class AttentionDescent(nn.Module):
    """Wrapper that delegates the descent step to a single ``LinearDescent``.

    The wrapped module is built with ``din=din`` and ``dout=3*din``
    (presumably the query, key and value projections stacked together --
    confirm against ``LinearDescent``).

    Args:
        config: Configuration object forwarded unchanged to ``LinearDescent``.
        din: Input dimension; the wrapped module's output dimension is
            ``3 * din``.
        num_attnt_heads: Accepted for interface compatibility; not used by
            this class.
        use_softmax: Forwarded unchanged to ``LinearDescent``.
        memory_index: Forwarded unchanged to ``LinearDescent`` (default -1).
    """

    def __init__ (self, config, din, num_attnt_heads, use_softmax, memory_index=-1):
        super().__init__()
        self.linear = LinearDescent(
            config,
            din=din,
            dout=3 * din,
            use_softmax=use_softmax,
            memory_index=memory_index,
        )

    def forward(self, hidden_states, position_states, attention_mask):
        """Return the wrapped ``LinearDescent`` output for the given inputs.

        The module is invoked via ``__call__`` (not ``.forward``) so that
        hooks registered on ``self.linear`` are executed.
        """
        return self.linear(hidden_states, position_states, attention_mask)